{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# TensorFlow tutorial\n",
    "\n",
    "In this tutorial we'll show how to build deep learning models in Tribuo, using Tribuo's [TensorFlow](https://tensorflow.org) interface. Tribuo uses [TensorFlow-Java](https://github.com/tensorflow/java) which is build by the TensorFlow [SIG-JVM group](https://github.com/tensorflow/community/blob/master/sigs/jvm/CHARTER.md). Tribuo's development team are active participants in SIG-JVM, and we're trying to make TensorFlow work well for everyone on the Java platform, in addition to making it work well inside Tribuo.\n",
    "\n",
    "Note that Tribuo's TensorFlow interface is not covered by the same stability guarantee as the rest of Tribuo. SIG-JVM has not released a 1.0 version of the TensorFlow Java API, and the API is currently in flux. When TensorFlow Java has API stability we'll be able to stabilize Tribuo's TensorFlow interface to provide the same guarantees as the rest of Tribuo.\n",
    "\n",
    "We're going to train MLPs (Multi-Layer Perceptrons) for classification and regression, along with a CNN (Convolutional Neural Network) for classifying MNIST digits. We'll discuss loading in externally trained TensorFlow models and serving them alongside Tribuo's natively trained models. Finally we'll see how to export TensorFlow models trained in Tribuo into TensorFlow's SavedModelBundle format for interop with TensorFlow Serving and the rest of the TensorFlow ecosystem.\n",
    "\n",
    "Unfortunately TensorFlow-Java has some non-determinism in it's gradient calculations which we're working on fixing in the TensorFlow-Java project, so repeated runs of this notebook will not produce identical answers, which unfortunately breaks some of Tribuo's provenance and reproducibility guarantees. When this is fixed upstream we'll apply any necessary fixes in Tribuo.\n",
    "\n",
    "## Setup\n",
    "\n",
    "You'll need to get a copy of the MNIST dataset in the original IDX format. You may have this from the configuration tutorial, in which case you can skip this step.\n",
    "\n",
    "First the training data:\n",
    "\n",
    "`wget http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz`\n",
    "\n",
    "Then the test data:\n",
    "\n",
    "`wget http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz`\n",
    "\n",
    "Tribuo's IDX loader natively reads gzipped files so you don't need to unzip them.\n",
    "\n",
    "We'll also need to download the winequality dataset from UCI. Again, if you've followed the regression tutorial you might already have this, so you can skip this step.\n",
    "\n",
    "`wget https://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-red.csv`\n",
    "\n",
    "Next we'll load the Tribuo TensorFlow jar and import the packages we'll need for the rest of the tutorial."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%jars ./tribuo-tensorflow-4.3.0-jar-with-dependencies.jar"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import java.nio.file.Path;\n",
    "import java.nio.file.Paths;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import org.tribuo.*;\n",
    "import org.tribuo.data.csv.CSVLoader;\n",
    "import org.tribuo.datasource.IDXDataSource;\n",
    "import org.tribuo.evaluation.TrainTestSplitter;\n",
    "import org.tribuo.classification.*;\n",
    "import org.tribuo.classification.evaluation.*;\n",
    "import org.tribuo.interop.tensorflow.*;\n",
    "import org.tribuo.interop.tensorflow.example.*;\n",
    "import org.tribuo.regression.*;\n",
    "import org.tribuo.regression.evaluation.*;\n",
    "import org.tribuo.util.Util;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import org.tensorflow.*;\n",
    "import org.tensorflow.framework.initializers.*;\n",
    "import org.tensorflow.ndarray.Shape;\n",
    "import org.tensorflow.op.*;\n",
    "import org.tensorflow.op.core.*;\n",
    "import org.tensorflow.types.*;"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading the data\n",
    "\n",
    "This is the same as the configuration and regression tutorials respectively, first we instantiate a `DataSource` for the particular dataset, then feed the data sources into datasets. We'll need to split the wine quality dataset into train & test as it doesn't have a predefined train/test split."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "// First we load winequality\n",
    "var regressionFactory = new RegressionFactory();\n",
    "var regEval = new RegressionEvaluator();\n",
    "var csvLoader = new CSVLoader<>(';',regressionFactory);\n",
    "var wineSource = csvLoader.loadDataSource(Paths.get(\"winequality-red.csv\"),\"quality\");\n",
    "var wineSplitter = new TrainTestSplitter<>(wineSource, 0.7f, 0L);\n",
    "var wineTrain = new MutableDataset<>(wineSplitter.getTrain());\n",
    "var wineTest = new MutableDataset<>(wineSplitter.getTest());\n",
    "\n",
    "// Now we load MNIST\n",
    "var labelFactory = new LabelFactory();\n",
    "var labelEval = new LabelEvaluator();\n",
    "var mnistTrainSource = new IDXDataSource<>(Paths.get(\"train-images-idx3-ubyte.gz\"),Paths.get(\"train-labels-idx1-ubyte.gz\"),labelFactory);\n",
    "var mnistTestSource = new IDXDataSource<>(Paths.get(\"t10k-images-idx3-ubyte.gz\"),Paths.get(\"t10k-labels-idx1-ubyte.gz\"),labelFactory);\n",
    "var mnistTrain = new MutableDataset<>(mnistTrainSource);\n",
    "var mnistTest = new MutableDataset<>(mnistTestSource);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Defining a TensorFlow graph\n",
    "\n",
    "Tribuo's TensorFlow API operates on TensorFlow graphs. You can construct those using TensorFlow's Java API, load in ones already generated by another TensorFlow API, or use one of Tribuo's example graph generators. We're going to define a simple MLP for the wine quality regression task in the notebook, but we'll use Tribuo's example graph generators for classifying MNIST (to make this tutorial a little shorter).\n",
    "\n",
    "TensorFlow Java is working on a higher level layer wise API (similar to [Keras](https://www.tensorflow.org/api_docs/python/tf/keras)), but at the moment we have to define the graph using the low level ops. Once the layer API is available in TensorFlow Java, we'll add entry points so that those APIs can be used with Tribuo, making the next section of this tutorial a lot shorter. For the moment it'll be rather long, but hopefully it's not too hard to follow.\n",
    "\n",
    "Tribuo's TensorFlow trainer will add the appropriate output node, loss function and gradient optimizer, so what you need to supply is the graph which emits the output (before any softmax, sigmoid or other output function), the name of the output op, the names of the input ops and the name of the graph initialization op.\n",
    "\n",
    "## Building a regression model using an MLP\n",
    "\n",
    "To solve this regression task we're going to build a 3 layer neural network, where each layer is a \"dense\" or \"MLP\" layer. We'll use a sigmoid as the activation function, but any supported one in TensorFlow will work. We'll need to know the number of input features and the number of output dimensions (i.e., the number of labels or regression dimensions), which is a little unfortunate as nothing else in Tribuo requires it, but it's required to build the structure."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "var wineGraph = new Graph();\n",
    "// This object is used to write operations into the graph\n",
    "var wineOps = Ops.create(wineGraph);\n",
    "var wineInputName = \"WINE_INPUT\";\n",
    "long wineNumFeatures = wineTrain.getFeatureMap().size();\n",
    "var wineInitializer = new Glorot<TFloat32>(// Initializer distribution\n",
    "                                           VarianceScaling.Distribution.TRUNCATED_NORMAL,\n",
    "                                           // Initializer seed\n",
    "                                           Trainer.DEFAULT_SEED\n",
    "                                          );\n",
    "\n",
    "// The input placeholder that we'll feed the features into\n",
    "var wineInput = wineOps.withName(wineInputName).placeholder(TFloat32.class,\n",
    "                Placeholder.shape(Shape.of(-1, wineNumFeatures)));\n",
    "                \n",
    "// Fully connected layer (numFeatures -> 30)\n",
    "var fc1Weights = wineOps.variable(wineInitializer.call(wineOps,wineOps.array(wineNumFeatures, 30L),\n",
    "                                                   TFloat32.class));\n",
    "var fc1Biases = wineOps.variable(wineOps.fill(wineOps.array(30), wineOps.constant(0.1f)));\n",
    "var sigmoid1 = wineOps.math.sigmoid(wineOps.math.add(wineOps.linalg.matMul(wineInput, fc1Weights),\n",
    "                                             fc1Biases));\n",
    "\n",
    "// Fully connected layer (30 -> 20)\n",
    "var fc2Weights = wineOps.variable(wineInitializer.call(wineOps,wineOps.array(30L, 20L), \n",
    "                                                       TFloat32.class));\n",
    "var fc2Biases = wineOps.variable(wineOps.fill(wineOps.array(20), wineOps.constant(0.1f)));\n",
    "var sigmoid2 = wineOps.math.sigmoid(wineOps.math.add(wineOps.linalg.matMul(sigmoid1, fc2Weights), \n",
    "                                             fc2Biases));\n",
    "\n",
    "// Output layer (20 -> 1)\n",
    "var outputWeights = wineOps.variable(wineInitializer.call(wineOps,wineOps.array(20L, 1L), \n",
    "                                                          TFloat32.class));\n",
    "var outputBiases = wineOps.variable(wineOps.fill(wineOps.array(1), wineOps.constant(0.1f)));\n",
    "var outputOp = wineOps.math.add(wineOps.linalg.matMul(sigmoid2, outputWeights), outputBiases);\n",
    "\n",
    "// Get the operation name to pass into the trainer\n",
    "var wineOutputName = outputOp.op().name();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can query the operation names by asking the various objects for their `name()`, which Tribuo will use to supply the appropriate inputs and outputs to the graph during training and inference.\n",
    "\n",
    "Now we have the graph, input name, output name and init name, stored in `wineGraph`, `wineInputName`, `wineOutputName` and `wineInitName` respectively. Next we'll define the gradient optimization algorithm and it's hyperparameters. These are separate from Tribuo's built in gradient optimizers as they are part of the TensorFlow native library, but it turns out that most of the same algorithms are available. We're going to use [AdaGrad](https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Adagrad), set it's learning rate to `0.1f` and the initial accumulator value to `0.01f`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "var gradAlgorithm = GradientOptimiser.ADAGRAD;\n",
    "var gradParams = Map.of(\"learningRate\",0.1f,\"initialAccumulatorValue\",0.01f);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We also need to create an object to convert from Tribuo's feature representation to a TensorFlow `Tensor`, and an object that can convert to and from `Tensor` and `Regressor`. These are defined using the `ExampleTransformer` and `OutputTransformer` interfaces. \n",
    "\n",
    "### Converting Features into Tensors with FeatureConverter\n",
    "Tribuo provides two implementations of `FeatureConverter`, one for dense inputs (like those used by MLPs) called `DenseFeatureConverter` and one for image shaped inputs (like those used by CNNs) called `ImageConverter`. If you need more specialised transformations (e.g., text) then you should implement the `FeatureConverter` interface and tailor it to your task's needs. \n",
    "\n",
    "The `FeatureConverter` needs the name of the input placeholder which the features will be fed into, so it can produce the appropriate values in the Map that is fed into the TensorFlow graph.\n",
    "\n",
    "### Converting Outputs into Tensors (and back again) with OutputConverter\n",
    "There are implementations of `OutputConverter` for `Label`, `MultiLabel` and `Regressor`, as those cover the main use cases for TensorFlow. You are free to implement these interfaces for more specialised use cases, though they should be thread-safe and idempotent. The `OutputConverter` contains the loss function and output function which is used to attach the appropriate training hooks to the graph. `LabelConverter` uses the [softmax](https://en.wikipedia.org/wiki/Softmax_function) function to produce probabilistic outputs, and the [Categorical Cross Entropy](https://www.tensorflow.org/api_docs/python/tf/keras/losses/CategoricalCrossentropy) to provide the loss for back-propagation. `RegressorConverter` uses the identity function to produce the output (as it's already producing a real value), and the [Mean-Squared Error](https://www.tensorflow.org/api_docs/python/tf/keras/losses/MeanSquaredError) as the loss function. `MultiLabelConverter` uses an independent sigmoid function for each label as the output, thresholded at 0.5, and [Binary Cross Entropy](https://www.tensorflow.org/api_docs/python/tf/keras/losses/BinaryCrossentropy) as the loss function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "var wineDenseConverter = new DenseFeatureConverter(wineInputName);\n",
    "var wineOutputConverter = new RegressorConverter();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We're finally ready to build our first `TensorFlowTrainer`. We need to specify a few more parameters in the constructor, namely the training batch size, the test batch size, and the number of training epochs. We'll set the batch sizes to 16 for all experiments, and we use 100 epochs for the regression task (because it's a small dataset), 20 epochs for the MNIST MLP, and 3 for the MNIST CNN (as the CNN converges much faster than the MLP)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "var wineTrainer = new TensorFlowTrainer<Regressor>(wineGraph,\n",
    "                wineOutputName,\n",
    "                gradAlgorithm,\n",
    "                gradParams,\n",
    "                wineDenseConverter,\n",
    "                wineOutputConverter,\n",
    "                16,   // training batch size\n",
    "                100,  // number of training epochs\n",
    "                16,   // test batch size of the trained model\n",
    "                -1    // disable logging of the loss value\n",
    "                );\n",
    "\n",
    "// Now we close the original graph to free the associated native resources.\n",
    "// The TensorFlowTrainer keeps a copy of the GraphDef protobuf to rebuild it when necessary.\n",
    "wineGraph.close();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`TensorFlowTrainer` will accept a `Graph`, a `GraphDef` protobuf, or a path to a `GraphDef` protobuf on disk. The `Graph` should be closed after it's supplied to the trainer, to free the native resources associated with it. Tribuo manages a copy of the `Graph` inside the trainer so users don't need to worry about resource allocation. The trainer automatically adds the loss function, gradient update operations and the final output operation to the supplied graph. \n",
    "\n",
    "We can use this trainer the way we'd use any other Tribuo trainer, we call `trainer.train()` and pass it in a dataset. In the case of TensorFlow it will throw an IllegalArgumentException if the number of features or outputs in the training dataset doesn't match what the trainer is expecting, as those parameters are coupled to the graph structure."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Wine quality training took (00:00:02:255)\n"
     ]
    }
   ],
   "source": [
    "var wineStart = System.currentTimeMillis();\n",
    "var wineModel = wineTrainer.train(wineTrain);\n",
    "var wineEnd = System.currentTimeMillis();\n",
    "System.out.println(\"Wine quality training took \" + Util.formatDuration(wineStart,wineEnd));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And we can evaluate it in the same way we evaluate other Tribuo regression models:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Wine quality evaluation:\n",
      "  RMSE 0.649654\n",
      "  MAE 0.507282\n",
      "  R^2 0.351000\n",
      "\n"
     ]
    }
   ],
   "source": [
    "var wineEvaluation = regEval.evaluate(wineModel,wineTest);\n",
    "var dimension = new Regressor(\"DIM-0\",Double.NaN);\n",
    "System.out.println(String.format(\"Wine quality evaluation:%n  RMSE %f%n  MAE %f%n  R^2 %f%n\",\n",
    "            wineEvaluation.rmse(dimension),\n",
    "            wineEvaluation.mae(dimension),\n",
    "            wineEvaluation.r2(dimension)));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see the MLP did ok there, and it's managed to fit the task almost as well as the tree ensemble we showed in the regression tutorial. If we standardised the regression variable, and performed further tuning of the architecture and gradient parameters we could improve on this, but let's move on to classification.\n",
    "\n",
    "## Building a classification model using an MLP\n",
    "\n",
    "Building classification models using the TensorFlow interface is pretty similar to building regression models, thanks to Tribuo's common API for these tasks. The differences come in the choice of `OutputConverter`.\n",
    "\n",
    "We're going to use Tribuo's `MLPExamples` and `CNNExamples` to build the networks for MNIST, as it's a bit shorter. These classes build simple predefined TensorFlow `Graph`s which are useful for demos, Tribuo's tests and getting started with deep learning. Currently there aren't many options in those classes, but we plan to expand them over time, and we welcome community contributions to do so. If you're interested in how the graphs are constructed you can check out the source for them on [GitHub](https://github.com/oracle/tribuo). For complex tasks we recommend that users build their own `Graph`s just as we did in the regression portion of the tutorial. TensorFlow-Java exposes a wide variety of [operations](https://tensorflow.org/jvm) for building graphs, and as the high level API improves it will become easier to specify complex structures.\n",
    "\n",
    "Tribuo's graph building functions return a `GraphDefTuple`, which is a nominal tuple for a `GraphDef` along with the strings representing the necessary operation names. As Tribuo targets Java 8 and upwards it's not a `java.lang.Record`, but it will be one day."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "var mnistInputName = \"MNIST_INPUT\";\n",
    "var mnistMLPTuple = MLPExamples.buildMLPGraph(\n",
    "                        mnistInputName, // The input placeholder name\n",
    "                        mnistTrain.getFeatureMap().size(), // The number of input features\n",
    "                        new int[]{300,200,30}, // The hidden layer sizes\n",
    "                        mnistTrain.getOutputs().size() // The number of output labels\n",
    "                        );\n",
    "var mnistDenseConverter = new DenseFeatureConverter(mnistInputName);\n",
    "var mnistOutputConverter = new LabelConverter();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This built an MLP with 3 hidden layers. The first maps from the feature space to an internal dimension of size 300, then the second is also of size 200, and the third has an internal dimension of 30. Tribuo then adds an output layer mapping down from those 30 dimensions to the 10 output dimensions in MNIST, one per digit.\n",
    "\n",
    "We'll use the same gradient optimiser as before, along with the same hyperparameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "var mnistMLPTrainer = new TensorFlowTrainer<Label>(mnistMLPTuple.graphDef,\n",
    "                mnistMLPTuple.outputName, // the name of the logit operation\n",
    "                gradAlgorithm,            // the gradient descent algorithm\n",
    "                gradParams,               // the gradient descent hyperparameters\n",
    "                mnistDenseConverter,      // the input feature converter\n",
    "                mnistOutputConverter,     // the output label converter\n",
    "                16,  // training batch size\n",
    "                20,  // number of training epochs\n",
    "                16,  // test batch size of the trained model\n",
    "                -1   // disable logging of the loss value\n",
    "                );"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And we train the model as before:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MNIST MLP training took (00:01:08:593)\n"
     ]
    }
   ],
   "source": [
    "var mlpStart = System.currentTimeMillis();\n",
    "var mlpModel = mnistMLPTrainer.train(mnistTrain);\n",
    "var mlpEnd = System.currentTimeMillis();\n",
    "System.out.println(\"MNIST MLP training took \" + Util.formatDuration(mlpStart,mlpEnd));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And evaluate it in the standard way:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Class                           n          tp          fn          fp      recall        prec          f1\n",
      "0                             980         960          20          94       0.980       0.911       0.944\n",
      "1                           1,135       1,105          30          23       0.974       0.980       0.977\n",
      "2                           1,032         922         110          39       0.893       0.959       0.925\n",
      "3                           1,010         916          94         121       0.907       0.883       0.895\n",
      "4                             982         930          52          61       0.947       0.938       0.943\n",
      "5                             892         813          79          74       0.911       0.917       0.914\n",
      "6                             958         914          44          36       0.954       0.962       0.958\n",
      "7                           1,028         942          86          37       0.916       0.962       0.939\n",
      "8                             974         893          81         107       0.917       0.893       0.905\n",
      "9                           1,009         926          83          87       0.918       0.914       0.916\n",
      "Total                      10,000       9,321         679         679\n",
      "Accuracy                                                                    0.932\n",
      "Micro Average                                                               0.932       0.932       0.932\n",
      "Macro Average                                                               0.932       0.932       0.931\n",
      "Balanced Error Rate                                                         0.068\n",
      "               0       1       2       3       4       5       6       7       8       9\n",
      "0            960       0       1       1       4       3       5       1       2       3\n",
      "1              1   1,105       4       4       0       3       3       1      14       0\n",
      "2             14       5     922      42      12       2       7       4      22       2\n",
      "3             13       0       9     916       1      34       0       9      22       6\n",
      "4              3       0       3       0     930       0      11       3       5      27\n",
      "5             18       0       0      17       5     813       6       4      23       6\n",
      "6             11       4       1       0       5      16     914       0       6       1\n",
      "7             10       7      15       8       4       1       0     942       4      37\n",
      "8             12       2       6      28      10      12       4       2     893       5\n",
      "9             12       5       0      21      20       3       0      13       9     926\n",
      "\n"
     ]
    }
   ],
   "source": [
    "var mlpEvaluation = labelEval.evaluate(mlpModel,mnistTest);\n",
    "System.out.println(mlpEvaluation.toString());\n",
    "System.out.println(mlpEvaluation.getConfusionMatrix().toString());"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "An MLP works pretty well on MNIST, but when working with images it's usually better to exploit the natural structure, and for that we use a Convolutional Neural Network.\n",
    "\n",
    "## Training a Convolutional Neural Network\n",
    "\n",
    "This is an even smaller transition than the switch between regression and classification. All we need to do is supply a `ImageConverter` which knows the size and pixel depth of the images, and build an appropriately shaped CNN.\n",
    "\n",
    "We'll use `CNNExamples.buildLeNetGraph` to build a version of the venerable [LeNet 5](http://yann.lecun.com/exdb/lenet/) CNN. We specify the image shape (this method assumes images are square), the pixel depth and the number of outputs. So for MNIST that's 28 pixels across, a pixel depth of 255, and 10 output classes one per digit. We'll also need the appropriate `ImageConverter` which needs the name of the input placeholder, the width and height of the image (so allowing rectangular images), and the number of colour channels. MNIST is grayscale, so there's only a single colour channel."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "var mnistCNNTuple = CNNExamples.buildLeNetGraph(mnistInputName,28,255,mnistTrain.getOutputs().size());\n",
    "var mnistImageConverter = new ImageConverter(mnistInputName,28,28,1);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can build the trainer and train in the same way as before, but we will train for fewer epochs as the CNN converges faster:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MNIST CNN training took (00:01:59:597)\n"
     ]
    }
   ],
   "source": [
    "var mnistCNNTrainer = new TensorFlowTrainer<Label>(mnistCNNTuple.graphDef,\n",
    "                mnistCNNTuple.outputName, // the name of the logit operation\n",
    "                gradAlgorithm,            // the gradient descent algorithm\n",
    "                gradParams,               // the gradient descent hyperparameters\n",
    "                mnistImageConverter,      // the input feature converter\n",
    "                mnistOutputConverter,     // the output label converter\n",
    "                16, // training batch size\n",
    "                3,  // number of training epochs\n",
    "                16, // test batch size of the trained model\n",
    "                -1  // disable logging of the loss value\n",
    "                );\n",
    "                \n",
    "// Training the model\n",
    "var cnnStart = System.currentTimeMillis();\n",
    "var cnnModel = mnistCNNTrainer.train(mnistTrain);\n",
    "var cnnEnd = System.currentTimeMillis();\n",
    "System.out.println(\"MNIST CNN training took \" + Util.formatDuration(cnnStart,cnnEnd));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And evaluate it the standard way:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Class                           n          tp          fn          fp      recall        prec          f1\n",
      "0                             980         968          12          22       0.988       0.978       0.983\n",
      "1                           1,135       1,127           8          15       0.993       0.987       0.990\n",
      "2                           1,032       1,011          21          59       0.980       0.945       0.962\n",
      "3                           1,010         959          51          14       0.950       0.986       0.967\n",
      "4                             982         977           5          35       0.995       0.965       0.980\n",
      "5                             892         877          15          54       0.983       0.942       0.962\n",
      "6                             958         934          24           9       0.975       0.990       0.983\n",
      "7                           1,028         981          47          12       0.954       0.988       0.971\n",
      "8                             974         931          43          29       0.956       0.970       0.963\n",
      "9                           1,009         969          40          17       0.960       0.983       0.971\n",
      "Total                      10,000       9,734         266         266\n",
      "Accuracy                                                                    0.973\n",
      "Micro Average                                                               0.973       0.973       0.973\n",
      "Macro Average                                                               0.973       0.973       0.973\n",
      "Balanced Error Rate                                                         0.027\n",
      "               0       1       2       3       4       5       6       7       8       9\n",
      "0            968       0       3       0       0       0       2       2       4       1\n",
      "1              0   1,127       3       0       1       1       1       0       1       1\n",
      "2              3       2   1,011       5       5       0       1       1       4       0\n",
      "3              0       1      14     959       0      28       0       1       5       2\n",
      "4              0       0       1       0     977       0       1       0       0       3\n",
      "5              1       0       1       4       0     877       3       1       4       1\n",
      "6             10       3       1       1       3       5     934       0       1       0\n",
      "7              0       2      25       1       9       2       0     981       3       5\n",
      "8              8       3      10       1       3      11       1       2     931       4\n",
      "9              0       4       1       2      14       7       0       5       7     969\n",
      "\n"
     ]
    }
   ],
   "source": [
    "var cnnPredictions = cnnModel.predict(mnistTest);\n",
    "var cnnEvaluation = labelEval.evaluate(cnnModel,cnnPredictions,mnistTest.getProvenance());\n",
    "System.out.println(cnnEvaluation.toString());\n",
    "System.out.println(cnnEvaluation.getConfusionMatrix().toString());"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As we might expect, exploiting the structured nature of images lets us get better performance, with ~98% accuracy after only 3 epochs. There is a wide variety of different CNN architectures, each suited for different kinds of tasks. Some are even applied to sequential data like text.\n",
    "\n",
    "## Exporting and Importing TensorFlow models\n",
    "\n",
    "TensorFlow's canonical model storage format is the [`SavedModelBundle`](https://www.tensorflow.org/guide/saved_model). You can export TensorFlow models trained in Tribuo in this format by calling `model.exportModel(String path)` which writes a directory at that path which contains the model as a `SavedModel`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "var outputPath = \"./tf-cnn-mnist-model\";\n",
    "cnnModel.exportModel(outputPath);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Tribuo can also load in `SavedModel`s and serve them as an `ExternalModel`. See the external models tutorial for more details on how Tribuo works with models built in other packages. The short version is that you need to specify the mapping from Tribuo's feature names into the id numbers the model expects, and from the output indices to Tribuo's output dimensions. We'll show how to load in the CNN that we just exported, and validate that it gives the same predictions as the original.\n",
    "\n",
    "First we'll setup the feature and output mappings. This is easy in our case as we already have the relevant information, but in most cases this requires understanding how the features were prepared when the original model was trained. We discuss this in more detail in the external models tutorial."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "var outputMapping = new HashMap<Label,Integer>();\n",
    "for (var p : cnnModel.getOutputIDInfo()) {\n",
    "    outputMapping.put(p.getB(),p.getA());\n",
    "}\n",
    "var featureIDMap = cnnModel.getFeatureIDMap();\n",
    "var featureMapping = new HashMap<String,Integer>();\n",
    "for (var info : featureIDMap) {\n",
    "    featureMapping.put(info.getName(),featureIDMap.getID(info.getName()));\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we build the `TensorFlowSavedModelExternalModel` using it's factory, supplying the feature mapping, output mapping, the softmax output operation name, the image transformer, the label transformer and finally the path to the `SavedModel` on disk."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "var externalModel = TensorFlowSavedModelExternalModel.createTensorflowModel(\n",
    "                        labelFactory,             // the output factory\n",
    "                        featureMapping,           // the feature mapping\n",
    "                        outputMapping,            // the output mapping\n",
    "                        cnnModel.getOutputName(), // the name of the *softmax* output\n",
    "                        mnistImageConverter,      // the input feature converter\n",
    "                        mnistOutputConverter,     // The label converter\n",
    "                        outputPath.toString()     // path to the saved model\n",
    "                        );"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This model behaves like any other, so we pass it some test data and generate it's predictions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "var externalPredictions = externalModel.predict(mnistTest);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's compare the output predictions. It's a little convoluted, but we're going to compare each predicted probability distribution to make sure they are the same."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Predictions are equal\n"
     ]
    }
   ],
   "source": [
    "var isEqual = true;\n",
    "for (int i = 0; i < cnnPredictions.size(); i++) {\n",
    "    var tribuo = cnnPredictions.get(i);\n",
    "    var external = externalPredictions.get(i);\n",
    "    isEqual &= tribuo.getOutput().fullEquals(external.getOutput());\n",
    "    isEqual &= tribuo.distributionEquals(external);\n",
    "}\n",
    "System.out.println(\"Predictions are \" + (isEqual ? \"equal\" : \"not equal\"));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As we can see, the models produce identical predictions, which means that we've successfully exported all our model weights and managed to load them back in as an external model.\n",
    "\n",
    "## Conclusion\n",
    "\n",
    "We saw how to build MLPs and CNNs in Tribuo & TensorFlow for both regression and classification, along with how to export Tribuo-trained models into TensorFlow's format, and import TensorFlow SavedModels into Tribuo.\n",
    "\n",
    "By default Tribuo pulls in the CPU version of TensorFlow Java, but if you supply the GPU jar at runtime it will automatically run everything on a compatible Nvidia GPU. We'll look at exposing explicit GPU support from Tribuo as the relevant support matures in TensorFlow Java."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Java",
   "language": "java",
   "name": "java"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "Java",
   "pygments_lexer": "java",
   "version": "17.0.4.1+1-LTS-2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
